#include "ast/ast.hpp"
#include "ast/visitor.hpp"
#include "ast/typechecker.hpp"
#include "parser/parser.hpp"
#include "runtime/object.hpp"
#include "env.hpp"
#include "type/type.hpp"
#include "util/bufferedInputStream.hpp"

#include <climits>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <vector>

/*
 * Top-level driver: tokenize `stream`, parse it into a program and hand
 * that program to Parser::eval for type checking and evaluation.
 */
void eval(BufferedInputStream* stream) {
    Lexer lexer(stream);
    Parser parser{&lexer};
    parser.eval();
}

/*
 * Build a parser reading tokens from `lex`. No token is fetched yet; the
 * first lookahead is pulled lazily by get_token().
 */
Parser::Parser(Lexer* lex) : _lex(lex), _cur_token(NULL) {
}

/*
 * Return the current lookahead token without consuming it.
 * The lexer is only advanced when there is no cached lookahead, i.e. after
 * consume() has released the previous one.
 */
Token* Parser::get_token() {
    if (_cur_token != NULL) {
        return _cur_token;
    }

    _cur_token = _lex->next();
#ifdef DEBUG
    _cur_token->print();
#endif
    return _cur_token;
}

void Parser::consume() {
    delete _cur_token;
    _cur_token = NULL;
}

/*
 * trails -> "(" expression ")" trails?
 *
 * Parse one or more call suffixes applied to `left`, e.g. `f(a)(b)`.
 * Each suffix wraps the result so far in a CallNode. Expects the current
 * lookahead to be "(".
 */
Node* Parser::trails(Node* left) {
    Node* callee = left;
    do {
        consume(); // eat '('
        Node* argument = expression();
        match(get_token(), TokenType::RIGHT_PAR);
        callee = new CallNode(callee, argument);
    } while (get_token()->_tt == TokenType::LEFT_PAR);

    return callee;
}

/*
 * simple_stmt -> expression "=" expression
 *              | expression ENDL
 *              | expression "==" expression
 *              | expression
 *
 * A trailing ENDL is eaten only in the bare-expression form.
 */
Node* Parser::simple_stmt() {
    Node* lhs = expression();

    switch (get_token()->_tt) {
    case TokenType::ASN: {
        consume(); // eat '='
        Node* rhs = expression();
        return new AssignNode(lhs, rhs);
    }
    case TokenType::ENDL:
        consume();
        return lhs;
    case TokenType::EQUAL:
        return test(lhs);
    default:
        return lhs;
    }
}

/*
 * if_expr -> <left> IF or_test ELSE expression
 *
 * `left` is the already-parsed value for the true branch; the current
 * lookahead must be IF.
 */
Node* Parser::if_expr(Node* left) {
    match(get_token(), TokenType::IF);
    Node* cond = or_test();
    match(get_token(), TokenType::ELSE);
    Node* otherwise = expression();

    return new IfNode(cond, left, otherwise);
}

/*
 * test -> <left> "==" expression
 *
 * `left` is the already-parsed left operand; the current lookahead must
 * be "==".
 */
Node* Parser::test(Node* left) {
    match(get_token(), TokenType::EQUAL);
    Node* rhs = expression();
    return new CmpNode(N_EQUAL, left, rhs);
}

/*
 * expression -> lambda_def
 *             | or_test
 *             | or_test IF or_test ELSE expression
 */
Node* Parser::expression() {
    if (get_token()->_tt == TokenType::LAMBDA) {
        return lambda_def();
    }

    Node* value = or_test();
    if (get_token()->_tt != TokenType::IF) {
        return value;
    }
    return if_expr(value);
}

/*
 * or_test -> and_test ("||" and_test)*
 *
 * Left-associative: `a || b || c` becomes LogicOr(LogicOr(a, b), c).
 */
Node* Parser::or_test() {
    Node* result = and_test();
    for (Token* tok = get_token();
         tok != NULL && tok->_tt == TokenType::LOGIC_OR;
         tok = get_token()) {
        consume();
        Node* rhs = and_test();
        result = new LogicOrNode(result, rhs);
    }
    return result;
}

/*
 * and_test -> not_test ("&&" not_test)*
 *
 * Left-associative, like or_test and every other binary level:
 * `a && b && c` becomes LogicAnd(LogicAnd(a, b), c). The right operand is
 * parsed with not_test(); recursing into and_test() here would swallow all
 * remaining "&&" operators, so the loop would never run twice and the tree
 * would be right-nested.
 */
Node* Parser::and_test() {
    Node* a = not_test();
    Token* op = get_token();
    while (op != NULL && op->_tt == TokenType::LOGIC_AND) {
        consume();
        Node* b = not_test();
        a = new LogicAndNode(a, b);
        op = get_token();
    }

    return a;
}

/*
 * not_test -> "!" not_test
 *           | comparison
 */
Node* Parser::not_test() {
    if (get_token()->_tt != TokenType::LOGIC_NOT) {
        return comparison();
    }

    consume(); // eat '!'
    Node* operand = not_test();
    return new LogicNotNode(operand);
}

/*
 * comparison -> or_expr
 *             | or_expr "==" or_expr
 *
 * At most one "==" is taken here.
 */
Node* Parser::comparison() {
    Node* lhs = or_expr();
    if (get_token()->_tt != TokenType::EQUAL) {
        return lhs;
    }

    consume(); // eat '=='
    Node* rhs = or_expr();
    return new CmpNode(N_EQUAL, lhs, rhs);
}

/*
 * or_expr -> xor_expr ("|" xor_expr)*     (left-associative bitwise or)
 */
Node* Parser::or_expr() {
    Node* result = xor_expr();
    for (Token* tok = get_token();
         tok != NULL && tok->_tt == TokenType::BIT_OR;
         tok = get_token()) {
        consume();
        Node* rhs = xor_expr();
        result = new BitOrNode(result, rhs);
    }
    return result;
}

/*
 * xor_expr -> and_expr ("^" and_expr)*    (left-associative bitwise xor)
 */
Node* Parser::xor_expr() {
    Node* result = and_expr();
    for (Token* tok = get_token();
         tok != NULL && tok->_tt == TokenType::BIT_XOR;
         tok = get_token()) {
        consume();
        Node* rhs = and_expr();
        result = new BitXorNode(result, rhs);
    }
    return result;
}

/*
 * and_expr -> shift_expr ("&" shift_expr)*  (left-associative bitwise and)
 */
Node* Parser::and_expr() {
    Node* result = shift_expr();
    for (Token* tok = get_token();
         tok != NULL && tok->_tt == TokenType::BIT_AND;
         tok = get_token()) {
        consume();
        Node* rhs = shift_expr();
        result = new BitAndNode(result, rhs);
    }
    return result;
}

/*
 * shift_expr -> arith_expr (("<<" | ">>") arith_expr)*   (left-associative)
 */
Node* Parser::shift_expr() {
    Node* result = arith_expr();
    for (;;) {
        Token* tok = get_token();
        if (tok == NULL) {
            break;
        }
        // Remember the operator kind: consume() deletes the token.
        TokenType kind = tok->_tt;
        if (kind != TokenType::LEFT_SHIFT && kind != TokenType::RIGHT_SHIFT) {
            break;
        }

        consume();
        Node* rhs = arith_expr();
        if (kind == TokenType::RIGHT_SHIFT) {
            result = new RightShiftNode(result, rhs);
        } else {
            result = new LeftShiftNode(result, rhs);
        }
    }
    return result;
}

/*
 * arith_expr -> term (("+" | "-") term)*   (left-associative)
 */
Node* Parser::arith_expr() {
    Node* result = term();
    for (;;) {
        Token* tok = get_token();
        if (tok == NULL) {
            break;
        }
        // Remember the operator kind: consume() deletes the token.
        TokenType kind = tok->_tt;
        if (kind != TokenType::PLUS && kind != TokenType::MINUS) {
            break;
        }

        consume();
        Node* rhs = term();
        if (kind == TokenType::MINUS) {
            result = new SubNode(result, rhs);
        } else {
            result = new AddNode(result, rhs);
        }
    }
    return result;
}

/*
 * term -> factor (("*" | "/") factor)*   (left-associative)
 */
Node* Parser::term() {
    Node* result = factor();
    for (;;) {
        Token* tok = get_token();
        if (tok == NULL) {
            break;
        }
        // Remember the operator kind: consume() deletes the token.
        TokenType kind = tok->_tt;
        if (kind != TokenType::MULT && kind != TokenType::DIV) {
            break;
        }

        consume();
        Node* rhs = factor();
        if (kind == TokenType::DIV) {
            result = new DivNode(result, rhs);
        } else {
            result = new MulNode(result, rhs);
        }
    }
    return result;
}

/*
 * factor -> "-" factor
 *         | atom
 *         | atom trails
 *
 * Unary minus is desugared into the subtraction `0 - operand`.
 */
Node* Parser::factor() {
    if (get_token()->_tt == TokenType::MINUS) {
        consume();
        Node* zero = new ConstInt(new IntObject(0));
        return new SubNode(zero, factor());
    }

    Node* base = atom();
    if (get_token()->_tt != TokenType::LEFT_PAR) {
        return base;
    }
    return trails(base);
}

/*
 * atom -> INT | STRING | CHAR | BOOL_TRUE | BOOL_FALSE | NAME
 *       | "(" expression ")"
 *
 * Returns NULL without consuming anything when the lookahead does not
 * start an atom.
 */
Node* Parser::atom() {
    Token* tok = get_token();
    Node* node = NULL;

    switch (tok->_tt) {
    case TokenType::INT:
        node = new ConstInt(new IntObject(stoi(tok)));
        break;
    case TokenType::STRING:
        // Strip the surrounding '"' characters from the lexeme.
        node = new ConstString(new StringObject(tok->_value + 1, tok->_length - 2));
        break;
    case TokenType::CHAR:
        node = new ConstChar(new CharObject(*tok->_value));
        break;
    case TokenType::BOOL_TRUE:
        node = new ConstBool(true);
        break;
    case TokenType::BOOL_FALSE:
        node = new ConstBool(false);
        break;
    case TokenType::NAME:
        node = new VarNode(tok);
        break;
    case TokenType::LEFT_PAR:
        consume(); // eat '('
        node = expression();
        match(get_token(), TokenType::RIGHT_PAR);
        return node;
    default:
        return NULL;
    }

    // Every literal/name case reads the token above, then releases it here.
    consume();
    return node;
}

/*
 * Require token `t` to have type `tt` and consume the current lookahead.
 * Callers pass the current lookahead as `t`. On mismatch, print a
 * "line:column parse error" message and terminate the process.
 */
void Parser::match(Token* t, TokenType tt) {
    if (t->_tt != tt) {
        printf("%d:%d parse error: expected %s, got %s\n",
            t->sourceRange().begin()._line,
            t->sourceRange().begin()._column,
            toString(tt), toString(t->_tt));
        exit(1);
    }
    consume();
}

/*
 * lambda_def -> "$" NAME ":" type "=>" expression
 *             | "$" NAME ":" type "=>" suite
 *
 * The body is a suite when it starts with a bracket, otherwise a
 * simple statement.
 */
Node* Parser::lambda_def() {
    match(get_token(), TokenType::LAMBDA);

    // The VarDefNode is built before match() consumes (deletes) the token.
    Token* name_tok = get_token();
    VarDefNode* param = new VarDefNode(name_tok);
    match(name_tok, TokenType::NAME);
    match(get_token(), TokenType::COLON);
    param->set_type(type_stmt());
    match(get_token(), TokenType::LAM_DEF);

    Node* body = (get_token()->_tt == TokenType::LEFT_BRACKET)
        ? suite()
        : simple_stmt();
    return new LambdaDef(param, body);
}

/*
 * suite -> LEFT_BRACKET statements RIGHT_BRACKET
 *
 * Returns the ListNode holding the block's statements.
 */
Node* Parser::suite() {
    match(get_token(), TokenType::LEFT_BRACKET);
    ListNode* body = new ListNode();
    Node* block = statements(body);
    match(get_token(), TokenType::RIGHT_BRACKET);
    return block;
}

/*
 * statement -> ENDL | declare_stmt | typedef_stmt | if_stmt | simple_stmt
 *
 * An empty line produces no node (NULL).
 */
Node* Parser::statement() {
    switch (get_token()->_tt) {
    case TokenType::ENDL:
        consume(); // blank statement
        return NULL;
    case TokenType::VAR:
        return declare_stmt();
    case TokenType::TYPE:
        return typedef_stmt();
    case TokenType::IF:
        return if_stmt();
    default:
        return simple_stmt();
    }
}

/*
 * stmts -> stmt* (ENDMARKER | <before RIGHT_BRACKET>)
 *
 * Append every parsed statement to `nodelist` and return it.
 * An ENDMARKER is consumed; a RIGHT_BRACKET is left in place so that
 * suite() can match it. Blank statements (NULL) are not recorded.
 *
 * Implemented as a loop rather than one recursive call per statement, so
 * long programs cannot exhaust the call stack.
 */
Node* Parser::statements(ListNode* nodelist) {
    for (;;) {
        Token* t = get_token();
        if (t->_tt == TokenType::ENDMARKER) {
            consume();
            return nodelist;
        }
        if (t->_tt == TokenType::RIGHT_BRACKET) {
            return nodelist;
        }

        Node* stmt = statement();
        if (stmt != NULL) {
            nodelist->add(stmt);
        }
    }
}

/*
 * declare -> "var" NAME ":" type
 *          | "var" NAME ":" type "=" expression
 *          | "var" NAME "=" expression
 *
 * In the third form no type is set here; the type of NAME is left to be
 * decided from its init expression.
 * Exits the process with a message if the name or both type and init are
 * missing.
 */
Node* Parser::declare_stmt() {
    match(get_token(), TokenType::VAR);

    Token* t = get_token();
    if (t->_tt != TokenType::NAME) {
        printf("expect variable name, but get %s\n", toString(t->_tt));
        exit(1);
    }
    VarDefNode* result = new VarDefNode(t);
    consume();

    TokenType next = get_token()->_tt;
    if (next != TokenType::COLON && next != TokenType::ASN) {
        printf("declare variable '%s' without type.\n", result->name());
        exit(1);
    }

    if (next == TokenType::COLON) {
        consume(); // eat ':'
        result->set_type(type_stmt());
    }

    // Re-read the lookahead: after a type annotation it may be '='.
    if (get_token()->_tt == TokenType::ASN) {
        consume(); // eat '='
        result->set_init(expression());
    }

    return result;
}

/*
 * type -> type_atom type_trailer?
 *       | type_atom type_trailer? "->" type
 *
 * "->" is right-associative: `A -> B -> C` is `A -> (B -> C)`.
 */
Type* Parser::type_stmt() {
    Type* lhs = type_atom();

    if (get_token()->_tt == TokenType::LT) {
        lhs = new TypeFunctionApply(lhs, type_trailer());
    }

    // get_token() only peeks; it does not advance the parser.
    if (get_token()->_tt != TokenType::ARROW) {
        return lhs;
    }

    consume(); // eat '->'
    return new ArrowType(lhs, type_stmt());
}

/*
 * type_atom -> TYPE_INT | TYPE_BOOL | TYPE_STRING | TYPE_DOUBLE
 *            | TYPE_CHAR | TYPE_ANY
 *            | "(" type ")"
 *            | NAME                      (user-defined type)
 *
 * Exits the process with a message on any other token.
 */
Type* Parser::type_atom() {
    Token* t = get_token();
    switch (t->_tt) {
    case TokenType::TYPE_INT:
        consume();
        return IntType::get_instance();
    case TokenType::TYPE_BOOL:
        consume();
        return BoolType::get_instance();
    case TokenType::TYPE_STRING:
        consume();
        return StringType::get_instance();
    case TokenType::TYPE_DOUBLE:
        consume();
        return DoubleType::get_instance();
    case TokenType::TYPE_CHAR:
        consume();
        return CharType::get_instance();
    case TokenType::TYPE_ANY:
        consume();
        return AnyType::get_instance();
    case TokenType::LEFT_PAR: {
        consume();
        Type* ttype = type_stmt();
        match(get_token(), TokenType::RIGHT_PAR);
        return ttype;
    }
    case TokenType::NAME: {
        // Build the type before consume() deletes the token it reads.
        Type* ttype = new UserDefType(t);
        consume();
        return ttype;
    }
    default:
        // Terminate the line like every other diagnostic in the parser.
        printf("unrecognized type: %s\n", toString(t->_tt));
        exit(1);
    }

    return NULL;
}

/*
 * type_trailer -> "<" type ("," type)* ">"
 *
 * Collects the type arguments of a type application into a TypeArgs.
 */
TypeArgs* Parser::type_trailer() {
    match(get_token(), TokenType::LT);

    std::vector<Type*> vec;
    vec.push_back(type_stmt());

    // Re-read the lookahead on every iteration: consume() deletes the
    // current token, so a cached Token* would dangle after the first comma.
    while (get_token()->_tt == TokenType::COMMA) {
        consume();
        vec.push_back(type_stmt());
    }

    match(get_token(), TokenType::GT);

    TypeArgs* args = new TypeArgs(NULL, 0);
    args->init_length(vec.size());
    int i = 0;
    for (Type* arg : vec) {
        args->set_arg(i++, arg);
    }

    return args;
}

/*
 * tp_args -> NAME ("," NAME)*
 *
 * Parses the type-variable list of a generic typedef, i.e. what sits
 * between "<" and ">". The caller has already consumed "<" and is
 * responsible for matching the closing ">". This function never consumes
 * it, including for the empty list "<>" (which only prints a warning and
 * yields an empty TypeArgs).
 * Exits the process with a parse error if a name is missing.
 */
TypeArgs* Parser::tp_args() {
    TypeArgs* args = new TypeArgs(NULL, 0);

    Token* t = get_token();
    if (t->_tt == TokenType::GT) {
        // Leave ">" in place for the caller to match.
        printf("warning: Empty type variable list\n");
        return args;
    }

    std::vector<TypeVar*> vec;
    for (;;) {
        t = get_token();
        if (t->_tt != TokenType::NAME) {
            // Always fails here: reports "expected NAME" and exits.
            match(t, TokenType::NAME);
        }
        // Build the TypeVar before consume() deletes the token.
        vec.push_back(new TypeVar(t));
        consume();

        if (get_token()->_tt != TokenType::COMMA) {
            break;
        }
        consume(); // eat ','
    }

    args->init_length(vec.size());
    int i = 0;
    for (TypeVar* var : vec) {
        args->set_arg(i++, var);
    }

    return args;
}

/**
 * typedef_stmt -> "type" NAME "=" type_stmt
 *              | "type" NAME "<" tp_args ">" "=" type_stmt
 *
 * With a type-variable list, the defined type is wrapped in a
 * TypeFunction over those variables.
 */
Node* Parser::typedef_stmt() {
    match(get_token(), TokenType::TYPE);

    // The node is built before match() consumes (deletes) the name token.
    Token* name_tok = get_token();
    TypeDefNode* def = new TypeDefNode(name_tok);
    match(name_tok, TokenType::NAME);

    TypeArgs* params = NULL;
    if (get_token()->_tt == TokenType::LT) {
        consume(); // eat '<'
        params = tp_args();
        match(get_token(), TokenType::GT);
    }

    match(get_token(), TokenType::ASN);
    Type* body = type_stmt();

    if (params != NULL) {
        def->set_type(new TypeFunction(params, body));
    } else {
        def->set_type(body);
    }
    return def;
}

/*
 * if_stmt -> IF "(" or_test ")" suite ELSE suite
 *
 * The ELSE branch is mandatory.
 */
Node* Parser::if_stmt() {
    match(get_token(), TokenType::IF);

    match(get_token(), TokenType::LEFT_PAR);
    Node* condition = or_test();
    match(get_token(), TokenType::RIGHT_PAR);

    Node* when_true = suite();
    match(get_token(), TokenType::ELSE);
    Node* when_false = suite();

    return new IfNode(condition, when_true, when_false);
}

/*
 * Convert the lexeme of an INT token to its decimal value.
 * Assumes the lexer only produces '0'..'9' characters for INT tokens
 * (not checked here — TODO confirm against the lexer).
 *
 * Literals larger than INT_MAX are reported as a parse error and the
 * process exits, instead of overflowing a signed int (undefined behavior).
 */
int Parser::stoi(Token* data) {
    // Accumulate in a wider type; value stays <= INT_MAX before each
    // multiplication, so value * 10 + 9 always fits in long long.
    long long value = 0;
    for (int i = 0; i < data->_length; i++) {
        value = value * 10 + (data->_value[i] - '0');
        if (value > INT_MAX) {
            printf("%d:%d parse error: integer literal out of range\n",
                data->sourceRange().begin()._line,
                data->sourceRange().begin()._column);
            exit(1);
        }
    }

    return static_cast<int>(value);
}

void Parser::eval() {
    std::unique_ptr<ListNode> root(new ListNode());
    statements(root.get());
#ifdef DEBUG
    Dumper dumper;
    dumper.visit_node(root.get());
#endif
    {
        TypeChecker checker;
        checker.visit_node(root.get());
        if (!checker.status()) {
            return;
        }
    }
    Evaluator evaluator;
    evaluator.visit_node(root.get());
}
